{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "53c0d4d6-3d2b-45e5-90fa-ba7953496ec2",
   "metadata": {},
   "source": [
    "# LLM Evaluation and Tracing with W&B\n",
    "\n",
    "<!--- @wandbcode{dlai_04} -->\n",
    "\n",
    "## 1. Using Tables for Evaluation\n",
    "\n",
    "In this section, we will call an OpenAI LLM to generate names for our game assets. We will use W&B Tables to evaluate the generations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6512739b-fe35-4901-acb3-05df46b5ed9c",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time\n",
    "import datetime\n",
    "\n",
    "import openai\n",
    "\n",
    "from tenacity import (\n",
    "    retry,\n",
    "    stop_after_attempt,\n",
    "    wait_random_exponential, # for exponential backoff\n",
    ")  \n",
    "import wandb\n",
    "from wandb.sdk.data_types.trace_tree import Trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab496d3-4ea5-4546-a924-2a6b7082b105",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# get openai API key\n",
    "import os\n",
    "import openai\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "\n",
    "openai.api_key  = os.environ['OPENAI_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83639bac-5860-4db1-9867-7c89f3ca25a6",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "PROJECT = \"dlai_llm\"\n",
    "MODEL_NAME = \"gpt-3.5-turbo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb575380",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "wandb.login(anonymous=\"allow\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c304c2b-dcd8-463c-aba4-aa47094dc16b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "run = wandb.init(project=PROJECT, job_type=\"generation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e7bcf11",
   "metadata": {},
   "source": [
    "### Simple generations\n",
    "Let's start by generating names for our game assets using OpenAI `ChatCompletion`, and saving the resulting generations in W&B Tables. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2ab394b-295b-4cfa-aade-aa274003a56a",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "@retry(\n",
    "    wait=wait_random_exponential(min=1, max=60),\n",
    "    stop=stop_after_attempt(6),\n",
    ")\n",
    "def completion_with_backoff(**kwargs):\n",
    "    \"\"\"Call `openai.ChatCompletion.create` with retries.\n",
    "\n",
    "    Failed calls are retried with randomized exponential backoff\n",
    "    (`min=1`, `max=60`) and retrying stops after 6 attempts. All keyword\n",
    "    arguments are forwarded unchanged and the API response is returned as-is.\n",
    "    \"\"\"\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "736fe64f-5cca-4316-8842-588b948193de",
   "metadata": {
    "height": 454
   },
   "outputs": [],
   "source": [
    "def generate_and_print(system_prompt, user_prompt, table, n=5):\n",
    "    \"\"\"Request `n` chat completions, print each one, and add a row to `table`.\n",
    "\n",
    "    The system and user prompts are sent to `MODEL_NAME` through\n",
    "    `completion_with_backoff`. A single row is appended to `table` holding\n",
    "    the two prompts, the list of generations, the elapsed seconds of the\n",
    "    call, the response creation time, the model name reported by the API,\n",
    "    and the prompt / completion / total token counts.\n",
    "    \"\"\"\n",
    "    chat_messages = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "    ]\n",
    "\n",
    "    # time the API call only (this includes any retries)\n",
    "    started_at = time.time()\n",
    "    completion = completion_with_backoff(\n",
    "        model=MODEL_NAME,\n",
    "        messages=chat_messages,\n",
    "        n=n,\n",
    "    )\n",
    "    elapsed_time = time.time() - started_at\n",
    "\n",
    "    generations = [choice.message.content for choice in completion.choices]\n",
    "    for text in generations:\n",
    "        print(text)\n",
    "\n",
    "    usage = completion.usage\n",
    "    table.add_data(\n",
    "        system_prompt,\n",
    "        user_prompt,\n",
    "        generations,\n",
    "        elapsed_time,\n",
    "        datetime.datetime.fromtimestamp(completion.created),\n",
    "        completion.model,\n",
    "        usage.prompt_tokens,\n",
    "        usage.completion_tokens,\n",
    "        usage.total_tokens,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "690e6e0a-193b-41c8-86c4-526f8061dd94",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "system_prompt = \"\"\"You are a creative copywriter.\n",
    "You're given a category of game asset, \\\n",
    "and your goal is to design a name of that asset.\n",
    "The game is set in a fantasy world \\\n",
    "where everyone laughs and respects each other, \n",
    "while celebrating diversity.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "395880fa",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Define the W&B Table that stores the generations;\n",
    "# the column order matches the values passed to table.add_data in generate_and_print\n",
    "columns = [\n",
    "    \"system_prompt\",\n",
    "    \"user_prompt\",\n",
    "    \"generations\",\n",
    "    \"elapsed_time\",\n",
    "    \"timestamp\",\n",
    "    \"model\",\n",
    "    \"prompt_tokens\",\n",
    "    \"completion_tokens\",\n",
    "    \"total_tokens\",\n",
    "]\n",
    "table = wandb.Table(columns=columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb07587",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "user_prompt = \"hero\"\n",
    "generate_and_print(system_prompt, user_prompt, table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8343121b-2d47-47d1-b343-ec2393b8f02f",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "user_prompt = \"jewel\"\n",
    "generate_and_print(system_prompt, user_prompt, table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3266487e-150b-4dd8-9555-94e94a66aac1",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "wandb.log({\"simple_generations\": table})\n",
    "run.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16d6d513-389d-4c67-a942-a922bce6ff1a",
   "metadata": {},
   "source": [
    "## 2. Using Tracer to log more complex chains\n",
    "\n",
    "How can we get more creative outputs? Let's design an LLM chain that will first randomly pick a fantasy world, and then generate character names. We will demonstrate how to use Tracer in such a scenario. We will log the inputs and outputs, start and end times, whether the OpenAI call was successful, the token usage, and additional metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c9fd404-51fd-44cf-b41e-b81dc589a4af",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "worlds = [\n",
    "    \"a mystic medieval island inhabited by intelligent and funny frogs\",\n",
    "    \"a modern castle sitting on top of a volcano in a faraway galaxy\",\n",
    "    \"a digital world inhabited by friendly machine learning engineers\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0db1e20a-87a8-4386-9a8d-727db9569cd7",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "# define your config\n",
    "# reuse the notebook-wide model constant so both sections always query the same model\n",
    "model_name = MODEL_NAME\n",
    "temperature = 0.7\n",
    "system_message = \"\"\"You are a creative copywriter. \n",
    "You're given a category of game asset and a fantasy world.\n",
    "Your goal is to design a name of that asset.\n",
    "Provide the resulting name only, no additional description.\n",
    "Single name, max 3 words output, remember!\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a86f95e-ed0d-4989-8c1d-5b88cdac7999",
   "metadata": {
    "height": 1355
   },
   "outputs": [],
   "source": [
    "def run_creative_chain(query):\n",
    "    \"\"\"Run the two-step naming chain for `query` and log it as a W&B trace.\n",
    "\n",
    "    A random entry of `worlds` is combined with `query` into a prompt, which\n",
    "    is sent to the OpenAI chat model. Three spans are recorded: a root\n",
    "    'chain' span with a 'tool' child (world pick) and an 'llm' child (OpenAI\n",
    "    call). The trace is logged as 'creative_trace' even if the OpenAI call\n",
    "    fails; the LLM span is then marked status_code='error' and the exception\n",
    "    is re-raised after logging.\n",
    "\n",
    "    Expects an active W&B run (the notebook calls `wandb.init` beforehand).\n",
    "    \"\"\"\n",
    "    # part 1 - a chain is started...\n",
    "    start_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
    "\n",
    "    root_span = Trace(\n",
    "        name=\"MyCreativeChain\",\n",
    "        kind=\"chain\",\n",
    "        start_time_ms=start_time_ms,\n",
    "        metadata={\"user\": \"student_1\"},\n",
    "        model_dict={\"_kind\": \"CreativeChain\"},\n",
    "    )\n",
    "\n",
    "    # part 2 - your chain picks a fantasy world\n",
    "    time.sleep(3)  # artificial delay (presumably so the tool span has a visible duration)\n",
    "    world = random.choice(worlds)\n",
    "    expanded_prompt = f'Game asset category: {query}; fantasy world description: {world}'\n",
    "    tool_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
    "\n",
    "    # create a Tool span\n",
    "    tool_span = Trace(\n",
    "        name=\"WorldPicker\",\n",
    "        kind=\"tool\",\n",
    "        status_code=\"success\",\n",
    "        start_time_ms=start_time_ms,\n",
    "        end_time_ms=tool_end_time_ms,\n",
    "        inputs={\"input\": query},\n",
    "        outputs={\"result\": expanded_prompt},\n",
    "        model_dict={\"_kind\": \"tool\", \"num_worlds\": len(worlds)},\n",
    "    )\n",
    "\n",
    "    # add the TOOL span as a child of the root\n",
    "    root_span.add_child(tool_span)\n",
    "\n",
    "    # part 3 - the LLMChain calls an OpenAI LLM...\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": expanded_prompt},\n",
    "    ]\n",
    "\n",
    "    llm_error = None\n",
    "    try:\n",
    "        response = completion_with_backoff(\n",
    "            model=model_name,\n",
    "            messages=messages,\n",
    "            max_tokens=12,\n",
    "            temperature=temperature,\n",
    "        )\n",
    "    except Exception as exc:\n",
    "        # keep the failure so the trace can still be logged before re-raising\n",
    "        llm_error = exc\n",
    "    llm_end_time_ms = round(datetime.datetime.now().timestamp() * 1000)\n",
    "\n",
    "    if llm_error is None:\n",
    "        status_code, status_message = \"success\", None\n",
    "        response_text = response[\"choices\"][0][\"message\"][\"content\"]\n",
    "        token_usage = response[\"usage\"].to_dict()\n",
    "        llm_model_dict = {\"_kind\": \"Openai\", \"engine\": response[\"model\"], \"model\": response[\"object\"]}\n",
    "    else:\n",
    "        # no response object is available, so fall back to the requested model name\n",
    "        status_code, status_message = \"error\", str(llm_error)\n",
    "        response_text = \"\"\n",
    "        token_usage = {}\n",
    "        llm_model_dict = {\"_kind\": \"Openai\", \"engine\": model_name}\n",
    "\n",
    "    llm_span = Trace(\n",
    "        name=\"OpenAI\",\n",
    "        kind=\"llm\",\n",
    "        status_code=status_code,\n",
    "        status_message=status_message,\n",
    "        metadata={\"temperature\": temperature,\n",
    "                  \"token_usage\": token_usage,\n",
    "                  \"model_name\": model_name},\n",
    "        start_time_ms=tool_end_time_ms,\n",
    "        end_time_ms=llm_end_time_ms,\n",
    "        inputs={\"system_prompt\": system_message, \"query\": expanded_prompt},\n",
    "        outputs={\"response\": response_text},\n",
    "        model_dict=llm_model_dict,\n",
    "    )\n",
    "\n",
    "    # add the LLM span as a child of the Chain span...\n",
    "    root_span.add_child(llm_span)\n",
    "\n",
    "    # update the Chain span's inputs and outputs\n",
    "    root_span.add_inputs_and_outputs(\n",
    "        inputs={\"query\": query},\n",
    "        outputs={\"response\": response_text})\n",
    "\n",
    "    # update the Chain span's end time\n",
    "    root_span.end_time_ms = llm_end_time_ms\n",
    "\n",
    "    # part 4 - log all spans to W&B by logging the root span\n",
    "    root_span.log(name=\"creative_trace\")\n",
    "\n",
    "    # surface the original failure only after the trace has been logged\n",
    "    if llm_error is not None:\n",
    "        raise llm_error\n",
    "    print(f\"Result: {response_text}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8500843-6d4b-4fc6-93b9-4cadf5813e4a",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Let's start a new wandb run\n",
    "wandb.init(project=PROJECT, job_type=\"generation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7409a004",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "run_creative_chain(\"hero\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "538d7bf3-4ae1-4b57-8a96-a34ea0614ec3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "run_creative_chain(\"jewel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45de1fb0-3630-4673-8ac0-0dffe0a52071",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ccc075f-32bf-4451-b7ad-ab2a49cc86b6",
   "metadata": {},
   "source": [
    "## 3. LangChain agent\n",
    "\n",
    "In the third scenario, we'll introduce an agent that will use tools such as WorldPicker and NameValidator to come up with the ultimate name. We will also use LangChain here and demonstrate its W&B integration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "726e0a6a-699b-434d-8c51-7542b4f981dd",
   "metadata": {
    "height": 199
   },
   "outputs": [],
   "source": [
    "# Import things that are needed generically\n",
    "from langchain.agents import AgentType, initialize_agent\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.tools import BaseTool\n",
    "\n",
    "from typing import Optional\n",
    "\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForToolRun,\n",
    "    CallbackManagerForToolRun,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5738431a-e281-4abf-9837-44fec6811ff4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "wandb.init(project=PROJECT, job_type=\"generation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac08f78b-0962-4d84-b39a-21ee5e5d606b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "os.environ[\"LANGCHAIN_WANDB_TRACING\"] = \"true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539bc081-d1e3-4376-a817-23aa1d7ab2b3",
   "metadata": {
    "height": 726
   },
   "outputs": [],
   "source": [
    "class WorldPickerTool(BaseTool):\n",
    "    \"\"\"Tool that returns a randomly chosen game world description.\"\"\"\n",
    "\n",
    "    name = \"pick_world\"\n",
    "    description = \"pick a virtual game world for your character or item naming\"\n",
    "    # candidate worlds; one of them is returned by each _run call\n",
    "    worlds = [\n",
    "        \"a mystic medieval island inhabited by intelligent and funny frogs\",\n",
    "        \"a modern anthill featuring a cyber-ant queen and her cyber-ant-workers\",\n",
    "        \"a digital world inhabited by friendly machine learning engineers\",\n",
    "    ]\n",
    "\n",
    "    def _run(\n",
    "        self,\n",
    "        query: str,\n",
    "        run_manager: Optional[CallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Sleep for one second, then return a random entry of `worlds`.\n",
    "\n",
    "        `query` and `run_manager` are accepted but not used.\n",
    "        \"\"\"\n",
    "        time.sleep(1)\n",
    "        return random.choice(self.worlds)\n",
    "\n",
    "    async def _arun(\n",
    "        self,\n",
    "        query: str,\n",
    "        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Async execution is not implemented; always raises NotImplementedError.\"\"\"\n",
    "        raise NotImplementedError(\"pick_world does not support async\")\n",
    "        \n",
    "class NameValidatorTool(BaseTool):\n",
    "    \"\"\"Tool that accepts a name only if it is shorter than 20 characters.\"\"\"\n",
    "\n",
    "    name = \"validate_name\"\n",
    "    description = \"validate if the name is properly generated\"\n",
    "\n",
    "    def _run(\n",
    "        self,\n",
    "        query: str,\n",
    "        run_manager: Optional[CallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Sleep for one second, then report whether `query` passes the length check.\n",
    "\n",
    "        Returns a confirmation message containing `query` when it has fewer\n",
    "        than 20 characters, otherwise a message saying the name is too long.\n",
    "        \"\"\"\n",
    "        time.sleep(1)\n",
    "        if len(query) >= 20:\n",
    "            return \"This name is too long. It should be shorter than 20 characters.\"\n",
    "        return f\"This is a correct name: {query}\"\n",
    "\n",
    "    async def _arun(\n",
    "        self,\n",
    "        query: str,\n",
    "        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Async execution is not implemented; always raises NotImplementedError.\"\"\"\n",
    "        raise NotImplementedError(\"validate_name does not support async\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c515ee33-1d6f-47e7-aceb-845c363eee29",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(temperature=0.7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "989407f4-0e10-4446-90d1-992c3b4c9483",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "tools = [WorldPickerTool(), NameValidatorTool()]\n",
    "agent = initialize_agent(\n",
    "    tools, \n",
    "    llm, \n",
    "    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
    "    handle_parsing_errors=True,\n",
    "    verbose=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d4bd42d-9c95-4e02-8679-99ca43d0aa71",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "agent.run(\n",
    "    \"Find a virtual game world for me and imagine the name of a hero in that world\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbb5ea87-a9b9-462f-80bf-b56d681dec8c",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "agent.run(\n",
    "    \"Find a virtual game world for me and imagine the name of a jewel in that world\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d101fcd-cd7d-4ede-ad95-412c1cd72e46",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "agent.run(\n",
    "    \"Find a virtual game world for me and imagine the name of food in that world.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486c688c-2ca2-4fe5-8f22-afd194b3e34d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "093d8262-e063-4b76-a615-7cb2d81fa071",
   "metadata": {},
   "source": [
    "**Note**: LLM outputs are variable. Your results may not match those in the video."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "643f6295",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93462bd0",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
